""" Diffuser (Diffusion-based Planner) Implementaiton """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.norm_layers import BaseNormLayer
from sb3_jax.common.jax_layers import BaseFeaturesExtractor, FlattenExtractor
from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.utils import get_dummy_obs, get_dummy_act
from sb3_jax.du.policies import DiffusionBetaScheduler

from diffgro.utils.utils import print_b 
from diffgro.common.models.utils import apply_cond
from diffgro.common.models.helpers import MLP
from diffgro.common.models.diffusion import UNetDiffusion, Diffusion
from diffgro.diffgro.functions import calculate_grad
from diffgro.diffgro.policies import apply_mask


class Actor(BasePolicy):
    """Diffusion actor that denoises stacked (observation, action) trajectories.

    The actor wraps a `UNetDiffusion` network inside a `Diffusion` module,
    turns it into a pure function with `hk.transform`, and exposes:

    * `_predict`  - one network evaluation at denoising step `t`;
    * `_sample`   - one reverse-diffusion update of the trajectory;
    * `_denoise`  - the full reverse loop, optionally steered by `guide_fn`.
    """

    # Upper bound on the number of x10 / /10 rescaling steps in
    # `_grad_scale_exponent`; any finite non-zero float needs far fewer.
    _MAX_SCALE_STEPS = 400

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        seed: int = 1,
    ):
        """Store the hyper-parameters, build the noise schedule and the network.

        Args:
            observation_space: Environment observation space.
            action_space: Environment action space.
            net_arch: Stored on the instance; NOTE(review): `_build_act`
                currently hard-codes `dim_mults=(1, 4, 8)` and does not read it.
            activation_fn: Activation name forwarded to `UNetDiffusion`.
            domain: `'short'` conditions the network on `"lang"` only; any
                other value conditions it on `"lang"` and `"skill"`.
            horizon: Planning horizon (number of stacked steps).
            skill_dim: Dimension of the language / skill embeddings.
            emb_dim: Embedding dimension forwarded to `UNetDiffusion`.
            n_denoise: Number of denoising steps.
            cf_weight: Classifier-free guidance weight.
            predict_epsilon: Whether the network output is treated as noise
                (`True`) or used directly in the posterior mean (`False`).
            beta_scheduler: Name of the beta schedule.
            seed: Seed forwarded to `BasePolicy`.
        """
        super().__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        self.net_arch = net_arch
        self.activation_fn = activation_fn

        self.domain = domain
        print_b(f"Setting domain as {self.domain}")

        # embedding
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim

        # diffusion
        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler
        self.ddpm_dict = DiffusionBetaScheduler(None, None, n_denoise, beta_scheduler).schedule()

        # misc: each trajectory step is the flattened observation followed by the action
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)
        self.out_dim = self.obs_dim + self.act_dim

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the constructor parameters reported by the base class."""
        return super()._get_constructor_parameters()

    def _build_act(self, batch_keys: Dict[str, jax.Array]) -> hk.Module:
        """Create the `Diffusion` module around a `UNetDiffusion` network.

        Must be called inside an `hk.transform`-ed function.

        Args:
            batch_keys: Keys of the conditioning dict the network should use
                (`_build` passes a list of key names).
        """
        unet = UNetDiffusion(
            horizon=self.horizon,
            emb_dim=self.emb_dim,
            out_dim=self.out_dim,
            dim_mults=(1, 4, 8),
            attention=False,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn,
        )
        return Diffusion(
            diffusion=unet,
            n_denoise=self.n_denoise,
            ddpm_dict=self.ddpm_dict,
            guidance_weight=self.cf_weight,
            predict_epsilon=self.predict_epsilon,
            denoise_type='ddpm',
        )

    def _build(self) -> None:
        """Transform the network with Haiku and initialise `self.params`."""
        # dummy inputs: one trajectory of `horizon` stacked (obs, act) steps
        dummy_obs = get_dummy_obs(self.observation_space)
        dummy_act = get_dummy_act(self.action_space)
        dummy_obs_stack = jnp.repeat(dummy_obs, self.horizon, axis=0).reshape(1, self.horizon, -1)
        dummy_act_stack = jnp.repeat(dummy_act, self.horizon, axis=0).reshape(1, self.horizon, -1)
        dummy_traj = jnp.concatenate((dummy_obs_stack, dummy_act_stack), axis=-1)
        # The rng draws below keep their original order so that parameter
        # initialisation is unchanged for a given seed.
        dummy_lang = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))   # language embedding
        dummy_skill = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))  # skill embedding
        dummy_t = jnp.array([[1.]])

        def fn_act(x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array, denoise: bool, deterministic: bool):
            batch_keys = ["lang"] if self.domain == 'short' else ["lang", "skill"]
            act = self._build_act(batch_keys=batch_keys)
            return act(x_t, batch_dict, t, denoise, deterministic)

        # hk.transform returns an (init, apply) pair
        init_fn, self.pi = hk.transform(fn_act)
        batch_dict = {"lang": dummy_lang, "skill": dummy_skill}
        self.params = init_fn(next(self.rng), dummy_traj, batch_dict, dummy_t, denoise=False, deterministic=False)

    @partial(jax.jit, static_argnums=(0,4,5))
    def _pi(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        denoise: bool,
        deterministic: bool,
        params: hk.Params,
        rng=None
    ) -> Tuple[Tuple[jax.Array], Dict[str, jax.Array]]:
        """Jitted application of the transformed network.

        `self`, `denoise` and `deterministic` are static arguments of the jit.
        """
        return self.pi(params, rng, x_t, batch_dict, t, denoise, deterministic)

    def _predict(
        self,
        x_t: jax.Array,
        lang: jax.Array,
        t: int,
        skill: Optional[jax.Array] = None,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Evaluate the network once at denoising step `t`.

        Args:
            x_t: Current noisy trajectory; its leading axis is the batch.
            lang: Language embedding, passed under the `"lang"` key.
            t: Denoising step, broadcast to a `(batch, 1)` array.
            skill: Optional skill embedding, passed under the `"skill"` key.
            deterministic: Forwarded to the network.

        Returns:
            The network output and its info dict.
        """
        batch_dict = {"lang": lang, "skill": skill}
        ts = jnp.full((x_t.shape[0], 1), t)
        eps, info = self._pi(x_t, batch_dict, ts, False, deterministic, self.params, next(self.rng))
        return eps, info

    # one denoise timestep prediction with guidance
    @partial(jax.jit, static_argnums=(0,4))
    def _sample(
        self,
        x_t: jax.Array,
        eps: jax.Array,
        t: int,
        deterministic: bool,
        rng=None,
    ) -> jax.Array:
        """Apply one reverse-diffusion update to `x_t`.

        Args:
            x_t: Current trajectory.
            eps: Network output for this step (used as noise when
                `predict_epsilon` is set, otherwise fed to the posterior mean).
            t: Index into the schedule arrays of `self.ddpm_dict`.
            deterministic: If true, no sampling noise is added.
            rng: PRNG key used to draw the sampling noise.

        Returns:
            The updated trajectory.
        """
        batch_size = x_t.shape[0]
        if deterministic:
            noise = 0.
        else:
            noise = jax.random.normal(rng, shape=(batch_size, self.horizon, self.out_dim))

        sched = self.ddpm_dict
        if self.predict_epsilon:
            mean = sched.oneover_sqrta[t] * (x_t - sched.ma_over_sqrtmab_inv[t] * eps)
            return mean + sched.sqrt_beta_t[t] * noise
        mean = sched.posterior_mean_coef1[t] * eps + sched.posterior_mean_coef2[t] * x_t
        return mean + jnp.exp(0.5 * sched.posterior_log_beta[t]) * noise

    @staticmethod
    def _apply_constraint(x_t: jax.Array, cond: jax.Array, mask: Optional[jax.Array]) -> jax.Array:
        """Re-impose the conditioning on `x_t` via `apply_cond` or, with a mask, `apply_mask`."""
        if mask is None:
            return apply_cond(x_t, cond)
        return apply_mask(mask, x_t, cond)

    @classmethod
    def _grad_scale_exponent(cls, loss) -> int:
        """Return the integer `k` such that `|loss| * 10**k` lies in `[0.1, 1.0]`.

        A loss of exactly zero yields `k = 0`.

        Raises:
            ValueError: If `loss` is NaN or infinite (rescaling by powers of
                ten can never bring such a value into range), or if the bound
                on rescaling steps is exhausted.
        """
        magnitude = abs(loss)
        if not jnp.isfinite(magnitude):
            raise ValueError(f"guidance loss is not finite: {loss}")

        exponent = 0
        for _ in range(cls._MAX_SCALE_STEPS):
            if magnitude == 0.0 or 0.1 <= magnitude <= 1.0:
                return exponent
            if magnitude > 1.0:
                magnitude = magnitude / 10
                exponent -= 1
            else:
                magnitude = magnitude * 10
                exponent += 1
        raise ValueError(f"could not rescale guidance loss {loss} into [0.1, 1.0]")

    def _denoise(
        self,
        cond: jax.Array,
        lang: jax.Array,
        skill: Optional[jax.Array] = None,
        mask: Optional[jax.Array] = None,
        delta: float = 0.1,
        guide_fn: Optional[Callable] = None,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Run the full reverse-diffusion loop and return the planned trajectory.

        Args:
            cond: Conditioning values; its leading axis is the batch. It is
                re-imposed on the trajectory after every step.
            lang: Language embedding.
            skill: Optional skill embedding.
            mask: Optional mask; when given, `apply_mask` is used instead of
                `apply_cond` and the guidance gradient is multiplied by
                `1 - mask`.
            delta: Step size of the guidance update.
            guide_fn: Optional guidance function handed to `calculate_grad`.
            deterministic: Forwarded to `_predict` and `_sample`.

        Returns:
            The denoised trajectory of shape `(batch, horizon, out_dim)` and
            an empty info dict.

        Raises:
            ValueError: If the guidance loss is NaN or infinite.
        """
        batch_size = cond.shape[0]

        x_t = jax.random.normal(next(self.rng), shape=(batch_size, self.horizon, self.out_dim))
        x_t = self._apply_constraint(x_t, cond, mask)

        for t in range(self.n_denoise, 0, -1):
            eps, _ = self._predict(x_t, lang, t, skill, deterministic)

            # guidance is skipped for the first two (noisiest) steps
            if (guide_fn is not None) and (t <= self.n_denoise - 2):
                grad, grad_info = calculate_grad(guide_fn, eps, self.obs_dim)
                # Rescale the gradient by the power of ten that would bring
                # the loss magnitude into [0.1, 1.0]. A float base avoids the
                # huge Python integers that `10 ** k` yields for large k.
                exponent = self._grad_scale_exponent(grad_info['loss'])
                grad = grad * (10.0 ** exponent)
                # apply masking
                if mask is not None:
                    grad = (1 - mask) * grad
                # apply gradient
                eps = eps - delta * grad

            x_t = self._sample(x_t, eps, t, deterministic, next(self.rng))
            x_t = self._apply_constraint(x_t, cond, mask)
        return x_t, {}

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Replace `self.params` with the `"pi_params"` entry of `params`."""
        print_b("[diffuser/actor]: loading params")
        self.params = params["pi_params"]

    
class DiffuserPlannerPolicy(BasePolicy):
    """Policy wrapper that owns a diffusion `Actor` and its optimizer.

    The policy forwards its hyper-parameters to `Actor`, attaches an optax
    optimizer and optimizer state to it, and exposes `_predict`, which
    preprocesses the conditioning observations before running the actor's
    denoising loop.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # diffusion classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        """Store the hyper-parameters and build the actor and its optimizer.

        Args:
            observation_space: Environment observation space.
            action_space: Environment action space.
            lr_schedule: Learning-rate schedule passed to the optimizer.
            net_arch: Mapping with an `'act'` entry; defaults to
                `dict(act=(1, 4, 8))`.
            activation_fn: Activation name forwarded to the actor.
            domain: Either `'short'` or `'long'`.
            horizon: Planning horizon.
            skill_dim: Semantic skill embedding dimension.
            emb_dim: Embedding dimension.
            n_denoise: Number of denoising steps.
            cf_weight: Classifier-free guidance weight.
            predict_epsilon: Predict noise (`True`) or the original (`False`).
            beta_scheduler: Name of the beta schedule.
            squash_output: Forwarded to `BasePolicy`.
            features_extractor_class: Forwarded to `BasePolicy`.
            features_extractor_kwargs: Forwarded to `BasePolicy`.
            normalize_images: Accepted for signature compatibility; not used
                in this class.
            optimizer_class: Factory called with `learning_rate` and
                `optimizer_kwargs` to create the optimizer.
            optimizer_kwargs: Extra keyword arguments for `optimizer_class`.
            normalization_class: Optional normalization layer class.
            normalization_kwargs: Extra keyword arguments for it.
            seed: Seed forwarded to `BasePolicy` and to the actor.

        Raises:
            AssertionError: If `domain` is neither `'short'` nor `'long'`.
        """
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )

        if net_arch is None:
            net_arch = dict(act=(1,4,8))
        self.act_arch = net_arch['act']
        self.activation_fn = activation_fn

        self.domain = domain
        assert self.domain in ['short', 'long'], 'Domain should be either short or long'
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim

        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler

        # keyword arguments used by `make_act` to construct the Actor
        self.act_kwargs = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "activation_fn": self.activation_fn,
            "net_arch": self.act_arch,
            "domain": domain,
            "horizon": horizon,
            "skill_dim": skill_dim,
            "emb_dim": emb_dim,
            "n_denoise": n_denoise,
            "cf_weight": cf_weight,
            "predict_epsilon": predict_epsilon,
            "beta_scheduler": beta_scheduler,
            "seed": seed,
        }

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the keyword arguments describing this policy's configuration.

        `net_arch`, `activation_fn` and `domain` are included because they
        determine how the actor network is built; without them a policy
        rebuilt from this dict would fall back to the constructor defaults.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space,
                net_arch=dict(act=self.act_arch),
                activation_fn=self.activation_fn,
                domain=self.domain,
                horizon=self.horizon,
                skill_dim=self.skill_dim,
                emb_dim=self.emb_dim,
                n_denoise=self.n_denoise,
                cf_weight=self.cf_weight,
                predict_epsilon=self.predict_epsilon,
                beta_scheduler=self.beta_scheduler,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data

    def _build(self, lr_schedule: Schedule) -> None:
        """Create the optional normalization layer, the actor and its optimizer.

        Args:
            lr_schedule: Passed to the optimizer as `learning_rate`.
        """
        if self.normalization_class is not None:
            # kwargs default to None in the constructor; treat that as "no extras"
            norm_kwargs = self.normalization_kwargs or {}
            self.normalization_layer = self.normalization_class(self.observation_space.shape, **norm_kwargs)

        self.act = self.make_act()
        optim_kwargs = self.optimizer_kwargs or {}
        self.act.optim = self.optimizer_class(learning_rate=lr_schedule, **optim_kwargs)
        self.act.optim_state = self.act.optim.init(self.act.params)

    def make_act(self) -> Actor:
        """Construct the `Actor` from `self.act_kwargs`."""
        return Actor(**self.act_kwargs)

    def _predict(
        self,
        cond: jax.Array, # observation
        lang: jax.Array,
        skill: Optional[jax.Array] = None,
        mask: Optional[jax.Array] = None,
        delta: float = 0.1,
        guide_fn: Optional[Callable] = None,
        deterministic: bool = True,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Preprocess the conditioning input and run the actor's denoising loop.

        Args:
            cond: Conditioning input. A 2-D array is preprocessed as a whole.
                Otherwise it is indexed as a 3-D array whose last axis holds
                the observation features first and the action features last;
                only the observation part is preprocessed.
            lang: Language embedding.
            skill: Optional skill embedding.
            mask: Optional mask forwarded to the actor.
            delta: Guidance step size forwarded to the actor.
            guide_fn: Optional guidance function forwarded to the actor.
            deterministic: Forwarded to the actor.

        Returns:
            Whatever `Actor._denoise` returns: the planned trajectory and an
            info dict.
        """
        obs_dim, act_dim = self.act.obs_dim, self.act.act_dim
        if len(cond.shape) == 2:
            cond = self.preprocess(cond, training=False)
        else:
            obs, act = cond[:, :, :obs_dim], cond[:, :, -act_dim:]
            # preprocess is applied on flattened (N, obs_dim) rows, then restored
            obs = self.preprocess(obs.reshape(-1, obs_dim), training=False)
            obs = obs.reshape(-1, self.act.horizon, obs_dim)
            cond = jnp.concatenate((obs, act), axis=-1)
        return self.act._denoise(cond, lang, skill, mask, delta, guide_fn, deterministic)
